import torch

x = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [0, 1, 2]]
])

print("x.shape =", x.shape)
print(x)

idx = torch.argmax(x)
print("最大值索引(整体):", idx)
print("最大值:", x.flatten()[idx])

idx = torch.argmax(x, dim=0)
print("dim=0 求 argmax：")
print(idx)

idx = torch.argmax(x, dim=1)
print("dim=1 求 argmax：")
print(idx)

val, idx = torch.max(x, dim=2)
print("最大值：\n", val)
print("索引：\n", idx)
